{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Byte-Pair Encoding tokenization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install the Transformers, Datasets, and Evaluate libraries to run this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install datasets evaluate transformers[sentencepiece]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "corpus = [\n",
    "    \"This is the Hugging Face course.\",\n",
    "    \"This chapter is about tokenization.\",\n",
    "    \"This section shows several tokenizer algorithms.\",\n",
    "    \"Hopefully, you will be able to understand how they are trained and generate tokens.\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "# Load the pretrained \"gpt2\" tokenizer; later cells reuse its pre-tokenizer to split text into words.\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "defaultdict(int, {'This': 3, 'Ġis': 2, 'Ġthe': 1, 'ĠHugging': 1, 'ĠFace': 1, 'ĠCourse': 1, '.': 4, 'Ġchapter': 1,\n",
       "    'Ġabout': 1, 'Ġtokenization': 1, 'Ġsection': 1, 'Ġshows': 1, 'Ġseveral': 1, 'Ġtokenizer': 1, 'Ġalgorithms': 1,\n",
       "    'Hopefully': 1, ',': 1, 'Ġyou': 1, 'Ġwill': 1, 'Ġbe': 1, 'Ġable': 1, 'Ġto': 1, 'Ġunderstand': 1, 'Ġhow': 1,\n",
       "    'Ġthey': 1, 'Ġare': 1, 'Ġtrained': 1, 'Ġand': 1, 'Ġgenerate': 1, 'Ġtokens': 1})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "word_freqs = defaultdict(int)\n",
    "\n",
    "# Count how many times each pre-tokenized word appears across the whole corpus.\n",
    "for text in corpus:\n",
    "    pre_tokenized = tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text)\n",
    "    # pre_tokenize_str yields (word, offsets) pairs; only the word is needed here.\n",
    "    for word, _offsets in pre_tokenized:\n",
    "        word_freqs[word] += 1\n",
    "\n",
    "print(word_freqs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[ ',', '.', 'C', 'F', 'H', 'T', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'k', 'l', 'm', 'n', 'o', 'p', 'r', 's',\n",
       "  't', 'u', 'v', 'w', 'y', 'z', 'Ġ']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Base alphabet: every distinct character that occurs in any word, in sorted order.\n",
    "alphabet = sorted({letter for word in word_freqs for letter in word})\n",
    "\n",
    "print(alphabet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The vocabulary starts with the special token followed by the base alphabet (as a new list).\n",
    "vocab = [\"<|endoftext|>\", *alphabet]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each word starts out split into its individual characters.\n",
    "splits = {word: list(word) for word in word_freqs}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_pair_freqs(splits):\n",
    "    \"\"\"Count adjacent symbol pairs over all words, weighted by word frequency.\n",
    "\n",
    "    Reads the notebook-level `word_freqs` for the words and their frequencies.\n",
    "\n",
    "    Args:\n",
    "        splits: mapping from each word to its current list of symbols.\n",
    "\n",
    "    Returns:\n",
    "        A defaultdict mapping (left, right) symbol pairs to their summed frequency.\n",
    "    \"\"\"\n",
    "    counts = defaultdict(int)\n",
    "    for word, freq in word_freqs.items():\n",
    "        symbols = splits[word]\n",
    "        # Zipping the list with its one-step shift yields every adjacent pair;\n",
    "        # a word made of a single symbol yields none.\n",
    "        for pair in zip(symbols, symbols[1:]):\n",
    "            counts[pair] += freq\n",
    "    return counts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('T', 'h'): 3\n",
       "('h', 'i'): 3\n",
       "('i', 's'): 5\n",
       "('Ġ', 'i'): 2\n",
       "('Ġ', 't'): 7\n",
       "('t', 'h'): 3"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pair_freqs = compute_pair_freqs(splits)\n",
    "\n",
    "# Preview only the first six pairs (in insertion order).\n",
    "for key in list(pair_freqs)[:6]:\n",
    "    print(f\"{key}: {pair_freqs[key]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('Ġ', 't') 7"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Most frequent pair; on ties max() keeps the first one encountered.\n",
    "# The default mirrors the values used when there are no pairs at all.\n",
    "best_pair, max_freq = max(\n",
    "    pair_freqs.items(), key=lambda item: item[1], default=(\"\", None)\n",
    ")\n",
    "\n",
    "print(best_pair, max_freq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Record the first merge rule by hand and add the merged token to the vocabulary.\n",
    "# NOTE: re-running this cell appends \"Ġt\" to vocab again.\n",
    "merges = {(\"Ġ\", \"t\"): \"Ġt\"}\n",
    "vocab.append(\"Ġt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def merge_pair(a, b, splits):\n",
    "    \"\"\"Replace every adjacent occurrence of (a, b) with the merged symbol a + b.\n",
    "\n",
    "    Iterates over the words of the notebook-level `word_freqs` and updates\n",
    "    `splits` in place.\n",
    "\n",
    "    Args:\n",
    "        a: left symbol of the pair.\n",
    "        b: right symbol of the pair.\n",
    "        splits: mapping from each word to its current list of symbols.\n",
    "\n",
    "    Returns:\n",
    "        The same `splits` mapping, after the merge has been applied.\n",
    "    \"\"\"\n",
    "    merged_symbol = a + b\n",
    "    for word in word_freqs:\n",
    "        symbols = splits[word]\n",
    "        pos = 0\n",
    "        # Words with fewer than two symbols never enter the loop and are left untouched.\n",
    "        while pos < len(symbols) - 1:\n",
    "            if (symbols[pos], symbols[pos + 1]) != (a, b):\n",
    "                pos += 1\n",
    "                continue\n",
    "            # Stay at the same position after a merge so the new symbol is re-examined.\n",
    "            symbols = [*symbols[:pos], merged_symbol, *symbols[pos + 2 :]]\n",
    "        splits[word] = symbols\n",
    "    return splits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Ġt', 'r', 'a', 'i', 'n', 'e', 'd']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Apply the first merge to every word, then inspect one word that starts with \"Ġt\".\n",
    "splits = merge_pair(\"Ġ\", \"t\", splits)\n",
    "print(splits[\"Ġtrained\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab_size = 50\n",
    "\n",
    "# Greedily learn one merge per iteration until the vocabulary reaches the target size.\n",
    "while len(vocab) < vocab_size:\n",
    "    pair_freqs = compute_pair_freqs(splits)\n",
    "    if not pair_freqs:\n",
    "        # Every word is already a single symbol, so no pair is left to merge.\n",
    "        # Stop here instead of unpacking an empty best_pair into merge_pair().\n",
    "        break\n",
    "    # max() returns the first maximal key in insertion order, i.e. the same pair\n",
    "    # a manual scan with a strict \"<\" comparison would select.\n",
    "    best_pair = max(pair_freqs, key=pair_freqs.get)\n",
    "    new_token = best_pair[0] + best_pair[1]\n",
    "    splits = merge_pair(*best_pair, splits)\n",
    "    merges[best_pair] = new_token\n",
    "    vocab.append(new_token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{('Ġ', 't'): 'Ġt', ('i', 's'): 'is', ('e', 'r'): 'er', ('Ġ', 'a'): 'Ġa', ('Ġt', 'o'): 'Ġto', ('e', 'n'): 'en',\n",
       " ('T', 'h'): 'Th', ('Th', 'is'): 'This', ('o', 'u'): 'ou', ('s', 'e'): 'se', ('Ġto', 'k'): 'Ġtok',\n",
       " ('Ġtok', 'en'): 'Ġtoken', ('n', 'd'): 'nd', ('Ġ', 'is'): 'Ġis', ('Ġt', 'h'): 'Ġth', ('Ġth', 'e'): 'Ġthe',\n",
       " ('i', 'n'): 'in', ('Ġa', 'b'): 'Ġab', ('Ġtoken', 'i'): 'Ġtokeni'}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the merge rules; the dict keeps them in the order they were added.\n",
    "print(merges)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['<|endoftext|>', ',', '.', 'C', 'F', 'H', 'T', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'k', 'l', 'm', 'n', 'o',\n",
       " 'p', 'r', 's', 't', 'u', 'v', 'w', 'y', 'z', 'Ġ', 'Ġt', 'is', 'er', 'Ġa', 'Ġto', 'en', 'Th', 'This', 'ou', 'se',\n",
       " 'Ġtok', 'Ġtoken', 'nd', 'Ġis', 'Ġth', 'Ġthe', 'in', 'Ġab', 'Ġtokeni']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the vocabulary: special token, base alphabet, then the merged tokens.\n",
    "print(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize(text):\n",
    "    \"\"\"Tokenize `text` by replaying the learned BPE merge rules.\n",
    "\n",
    "    The text is pre-tokenized into words, each word is split into characters,\n",
    "    and every rule in the notebook-level `merges` is applied in insertion order.\n",
    "\n",
    "    Args:\n",
    "        text: the string to tokenize.\n",
    "\n",
    "    Returns:\n",
    "        A flat list of token strings.\n",
    "    \"\"\"\n",
    "    # Use the public `backend_tokenizer` accessor, as the word-frequency cell does,\n",
    "    # rather than the private `_tokenizer` attribute.\n",
    "    pre_tokenize_result = tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text)\n",
    "    # Named `word_splits` so it does not shadow the notebook-level `splits` dict.\n",
    "    word_splits = [list(word) for word, _offsets in pre_tokenize_result]\n",
    "    for pair, merge in merges.items():\n",
    "        for idx, split in enumerate(word_splits):\n",
    "            i = 0\n",
    "            while i < len(split) - 1:\n",
    "                if split[i] == pair[0] and split[i + 1] == pair[1]:\n",
    "                    split = split[:i] + [merge] + split[i + 2 :]\n",
    "                else:\n",
    "                    i += 1\n",
    "            word_splits[idx] = split\n",
    "\n",
    "    # Flatten with a comprehension; sum(lists, []) re-copies the accumulator at every step.\n",
    "    return [token for split in word_splits for token in split]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['This', 'Ġis', 'Ġ', 'n', 'o', 't', 'Ġa', 'Ġtoken', '.']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Try the trained tokenizer on a sentence that is not part of the corpus.\n",
    "tokenize(\"This is not a token.\")"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "name": "Byte-Pair Encoding tokenization",
   "provenance": []
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
